In [1]:
import os

## Set directory
os.chdir('/hpc/group/pbenfeylab/CheWei/CW_data/genesys')

import networkx as nx
from genesys_evaluate import *
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
In [2]:
## Conda Env pytorch-gpu on DCC -- record the torch and scanpy versions used for this run
print(torch.__version__, sc.__version__, sep='\n')
1.13.0.post200
1.9.1
In [3]:
## Genes considered/used (shared among samples). Their row order is the column
## order of the expression matrices fed to the model below.
GENE_LIST_CSV = './gene_list_1108.csv'
gene_list = pd.read_csv(GENE_LIST_CSV)

Load mutant data

Load mutant scRNA-seq data (the .h5ad files were converted from the Seurat objects deposited in GEO under accession GSE152766)

In [4]:
shr = sc.read_h5ad(filename="/hpc/group/pbenfeylab/CheWei/scRNA-seq/Integrated_Objects/shr_integrated.h5ad")  # shr integrated object
/hpc/group/pbenfeylab/ch416/miniconda3/envs/pytorch-gpu/lib/python3.8/site-packages/anndata/compat/__init__.py:232: FutureWarning: Moving element from .uns['neighbors']['distances'] to .obsp['distances'].

This is where adjacency matrices should go now.
  warn(
In [5]:
scr = sc.read_h5ad(filename="/hpc/group/pbenfeylab/CheWei/scRNA-seq/Integrated_Objects/scr_integrated.h5ad")  # scr integrated object
/hpc/group/pbenfeylab/ch416/miniconda3/envs/pytorch-gpu/lib/python3.8/site-packages/anndata/compat/__init__.py:232: FutureWarning: Moving element from .uns['neighbors']['distances'] to .obsp['distances'].

This is where adjacency matrices should go now.
  warn(
In [6]:
# Floor each dataset's expression at 0, then min-max rescale the whole matrix to
# [0, 1]. A single global min/max is used, not a per-gene one.
# NOTE(review): no ceiling (e.g. at 10) is applied before rescaling -- confirm this
# matches the preprocessing used when the GeneSys model was trained.
def _floor_and_rescale(adata):
    """Clip ``adata.X`` at 0 and min-max rescale it to [0, 1], in place.

    Raises ValueError when the matrix is constant, where the rescale would divide by zero.
    """
    expr = adata.X
    expr[expr < 0] = 0
    lo = np.amin(expr)
    span = np.amax(expr) - lo
    if span == 0:
        raise ValueError("expression matrix is constant; cannot rescale to [0, 1]")
    adata.X = (expr - lo) / span


for _adata in (scr, shr):
    _floor_and_rescale(_adata)
In [7]:
## Subset only the T0 cells and remove those annotated as elongated or mature.
## Keep cells whose crude annotation starts with "Pro" and whose time group starts with "T0".
_crude_anno = pd.Series(scr.obs['time.celltype.anno.Li.crude'])
_time_group = pd.Series(scr.obs['consensus.time.group'])
_keep_cells = np.array(_crude_anno.str.match("^Pro") & _time_group.str.match("T0"))
scr = scr[_keep_cells, :]
In [8]:
## Align the scr expression matrix to the model's gene order (gene_list).
## Genes absent from scr.var['features'] are filled with zeros.
gene_features = gene_list['features'].tolist()
mi = match(gene_features, scr.var['features'].tolist())  # None where a gene is absent

scrx = pd.DataFrame(scr.X.copy())
zero_col = len(scrx.columns)  # label of an extra all-zero column used as filler
scrx[zero_col] = 0.0
mi = np.array([zero_col if j is None else j for j in mi])
scrx = scrx[pd.Series(mi)]
scrx.columns = gene_features

## Draw batch_size cells WITH replacement, so a cell may appear more than once.
## Replacement is required whenever fewer than batch_size cells pass the filter.
## Seeded so that re-running the notebook samples the same cells.
batch_size = 2000
SAMPLING_SEED = 0
rng = np.random.default_rng(SAMPLING_SEED)
idx = rng.choice(len(scrx), size=batch_size, replace=True)
scrx = scrx.iloc[idx]
In [9]:
## 2000 cells sampled at random, with replacement, from the T0 progenitor ("Pro*") scr cells
scrx
Out[9]:
AT1G05260 AT3G59370 AT2G36100 AT1G12080 AT1G12090 AT4G11290 AT5G42180 AT5G66390 AT2G32300 AT2G02130 ... AT4G06395 AT3G55440 AT3G03100 AT5G54760 AT2G33040 AT2G42680 AT5G11770 AT5G08290 AT5G53300 AT5G64400
66 0.0 0.0 0.000000 0.0 0.0 0.0 0.000000 0.0 0.0 0.0 ... 0.000000 0.007915 0.001323 0.000000 0.000000 0.000000 0.003263 0.000000 0.000000 0.000160
163 0.0 0.0 0.000000 0.0 0.0 0.0 0.000000 0.0 0.0 0.0 ... 0.025064 0.008866 0.000000 0.000000 0.000000 0.004550 0.000000 0.036358 0.000000 0.030074
93 0.0 0.0 0.000000 0.0 0.0 0.0 0.000000 0.0 0.0 0.0 ... 0.000000 0.047018 0.009379 0.000000 0.000000 0.000000 0.014504 0.005235 0.000000 0.049956
95 0.0 0.0 0.000000 0.0 0.0 0.0 0.000000 0.0 0.0 0.0 ... 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.014331 0.000000 0.000000
109 0.0 0.0 0.000000 0.0 0.0 0.0 0.000000 0.0 0.0 0.0 ... 0.000000 0.000000 0.000000 0.104939 0.000000 0.001219 0.002244 0.025827 0.000000 0.000000
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
33 0.0 0.0 0.000000 0.0 0.0 0.0 0.000000 0.0 0.0 0.0 ... 0.000000 0.000000 0.000000 0.000000 0.007477 0.000000 0.000000 0.000000 0.000000 0.000000
5 0.0 0.0 0.000000 0.0 0.0 0.0 0.000000 0.0 0.0 0.0 ... 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.013880 0.000000 0.000000
32 0.0 0.0 0.014883 0.0 0.0 0.0 0.009129 0.0 0.0 0.0 ... 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000
9 0.0 0.0 0.000000 0.0 0.0 0.0 0.000000 0.0 0.0 0.0 ... 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.090622 0.064728 0.000000
107 0.0 0.0 0.000000 0.0 0.0 0.0 0.000000 0.0 0.0 0.0 ... 0.000000 0.026854 0.000000 0.000313 0.025249 0.000000 0.000000 0.000000 0.012724 0.016613

2000 rows × 17513 columns

In [10]:
## Sanity check: grand total of the sampled scr expression matrix
_row_totals = scrx.sum(axis='columns')
sum(_row_totals)
Out[10]:
1592340.0877264303
In [39]:
## Subset only the T0 cells and remove those annotated as elongated or mature.
## Keep cells whose crude annotation starts with "Pro" and whose time group starts with "T0".
_crude_anno = pd.Series(shr.obs['time.celltype.anno.Li.crude'])
_time_group = pd.Series(shr.obs['consensus.time.group'])
_keep_cells = np.array(_crude_anno.str.match("^Pro") & _time_group.str.match("T0"))
shr = shr[_keep_cells, :]
In [40]:
## Align the shr expression matrix to the model's gene order (gene_list).
## Genes absent from shr.var['features'] are filled with zeros.
gene_features = gene_list['features'].tolist()
mi = match(gene_features, shr.var['features'].tolist())  # None where a gene is absent

shrx = pd.DataFrame(shr.X.copy())
zero_col = len(shrx.columns)  # label of an extra all-zero column used as filler
shrx[zero_col] = 0.0
mi = np.array([zero_col if j is None else j for j in mi])
shrx = shrx[pd.Series(mi)]
shrx.columns = gene_features

## Draw batch_size cells WITH replacement, so a cell may appear more than once.
## Replacement is required whenever fewer than batch_size cells pass the filter.
## Seeded so that re-running the notebook samples the same cells.
batch_size = 2000
SAMPLING_SEED = 0
rng = np.random.default_rng(SAMPLING_SEED)
idx = rng.choice(len(shrx), size=batch_size, replace=True)
shrx = shrx.iloc[idx]
In [41]:
## 2000 cells sampled at random, with replacement, from the T0 progenitor ("Pro*") shr cells
shrx
Out[41]:
AT1G05260 AT3G59370 AT2G36100 AT1G12080 AT1G12090 AT4G11290 AT5G42180 AT5G66390 AT2G32300 AT2G02130 ... AT4G06395 AT3G55440 AT3G03100 AT5G54760 AT2G33040 AT2G42680 AT5G11770 AT5G08290 AT5G53300 AT5G64400
338 0.0 0.0 0.000000 0.0 0.0 0.0 0.0 0.0 0.0 0.000000 ... 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.012978 0.000000
229 0.0 0.0 0.000000 0.0 0.0 0.0 0.0 0.0 0.0 0.000000 ... 0.000000 0.000000 0.000000 0.006763 0.000000 0.005768 0.000000 0.028089 0.000000 0.000000
606 0.0 0.0 0.000000 0.0 0.0 0.0 0.0 0.0 0.0 0.000000 ... 0.000000 0.000000 0.002436 0.000000 0.000000 0.000000 0.000000 0.060249 0.000000 0.000000
373 0.0 0.0 0.000000 0.0 0.0 0.0 0.0 0.0 0.0 0.000000 ... 0.000000 0.000000 0.000000 0.007069 0.000000 0.000000 0.000000 0.185272 0.000000 0.000000
5 0.0 0.0 0.001678 0.0 0.0 0.0 0.0 0.0 0.0 0.000000 ... 0.015282 0.000000 0.000000 0.000000 0.000000 0.000000 0.010126 0.000000 0.000000 0.000000
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
485 0.0 0.0 0.000000 0.0 0.0 0.0 0.0 0.0 0.0 0.222418 ... 0.000000 0.045841 0.001403 0.000000 0.068373 0.000000 0.047298 0.000000 0.022758 0.000000
301 0.0 0.0 0.000000 0.0 0.0 0.0 0.0 0.0 0.0 0.000000 ... 0.000000 0.000000 0.002071 0.008318 0.000000 0.000000 0.000000 0.000000 0.032065 0.000000
251 0.0 0.0 0.000000 0.0 0.0 0.0 0.0 0.0 0.0 0.000000 ... 0.000000 0.000000 0.031146 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000
197 0.0 0.0 0.000000 0.0 0.0 0.0 0.0 0.0 0.0 0.000000 ... 0.000000 0.075849 0.022099 0.005851 0.000000 0.024259 0.027292 0.000000 0.073137 0.015918
264 0.0 0.0 0.000000 0.0 0.0 0.0 0.0 0.0 0.0 0.000000 ... 0.000000 0.114439 0.038295 0.020004 0.000000 0.040415 0.033262 0.055380 0.018655 0.030481

2000 rows × 17513 columns

In [42]:
## Sanity check: grand total of the sampled shr expression matrix
_row_totals = shrx.sum(axis='columns')
sum(_row_totals)
Out[42]:
1397350.106048115
In [43]:
## Number of input genes. This must equal both the width of the checkpoint's input
## layer (Linear(in_features=17513)) and the number of genes in gene_list, which
## fixes the column order of the expression matrices.
input_size = 17513
assert len(gene_list) == input_size, f'gene_list has {len(gene_list)} genes, model expects {input_size}'
## 10 cell types
output_size = 10
embedding_dim = 256
hidden_dim = 256
n_layers = 2
device = "cpu"
path = "./"  # directory holding the model checkpoint

Load the trained GeneSys model (evaluation mode)

In [44]:
## Build the GeneSys classifier, load the trained weights onto `device`, and switch
## to eval mode (disables dropout). torch.load unpickles the file, so only load
## checkpoints from a trusted source.
model = ClassifierLSTM(input_size, output_size, embedding_dim, hidden_dim, n_layers).to(device)
checkpoint_path = os.path.join(path, "best_ALL_1130_continue.pth")
model.load_state_dict(torch.load(checkpoint_path, map_location=torch.device(device)))
model.eval()
Out[44]:
ClassifierLSTM(
  (fc1): Sequential(
    (0): Linear(in_features=17513, out_features=256, bias=True)
    (1): Dropout(p=0.2, inplace=False)
    (2): GaussianNoise()
  )
  (fc): Sequential(
    (0): ReLU()
    (1): Linear(in_features=512, out_features=512, bias=True)
    (2): ReLU()
    (3): Linear(in_features=512, out_features=10, bias=True)
  )
  (lstm): LSTM(256, 256, num_layers=2, batch_first=True, dropout=0.2, bidirectional=True)
  (dropout): Dropout(p=0.2, inplace=False)
  (b_to_z): DBlock(
    (fc1): Linear(in_features=512, out_features=256, bias=True)
    (fc2): Linear(in_features=512, out_features=256, bias=True)
    (fc_mu): Linear(in_features=256, out_features=512, bias=True)
    (fc_logsigma): Linear(in_features=256, out_features=512, bias=True)
  )
  (bz2_infer_z1): DBlock(
    (fc1): Linear(in_features=1024, out_features=256, bias=True)
    (fc2): Linear(in_features=1024, out_features=256, bias=True)
    (fc_mu): Linear(in_features=256, out_features=512, bias=True)
    (fc_logsigma): Linear(in_features=256, out_features=512, bias=True)
  )
  (z1_to_z2): DBlock(
    (fc1): Linear(in_features=512, out_features=256, bias=True)
    (fc2): Linear(in_features=512, out_features=256, bias=True)
    (fc_mu): Linear(in_features=256, out_features=512, bias=True)
    (fc_logsigma): Linear(in_features=256, out_features=512, bias=True)
  )
  (z_to_x): Decoder(
    (fc1): Linear(in_features=512, out_features=256, bias=True)
    (fc2): Linear(in_features=256, out_features=256, bias=True)
    (fc3): Linear(in_features=256, out_features=17513, bias=True)
  )
)

Run GeneSys on the shr mutant input data

In [45]:
## Model output index -> cell-type name (and the inverse mapping)
classes = ['Columella', 'Lateral Root Cap', 'Phloem', 'Xylem', 'Procambium', 'Pericycle', 'Endodermis', 'Cortex', 'Atrichoblast', 'Trichoblast']
num2class = dict(enumerate(classes))
class2num = {name: idx for idx, name in num2class.items()}
In [59]:
## Replicate the same shr expression snapshot across all 11 time bins:
## x has shape (batch_size, 11, n_genes)
xm = torch.tensor(shrx.to_numpy(), dtype=torch.float32)
x = torch.stack([xm] * 11, dim=1)
In [60]:
## Roll GeneSys forward one time bin at a time. At step k the sequence fed to the
## model is [t0, ..., tk], padded to N_TIMEPOINTS by repeating tk; the model then
## predicts a cell-type label at step k. This is inference only, so autograd is
## disabled: otherwise the graphs of all generated batch_size x n_genes tensors
## would be kept in memory.
N_TIMEPOINTS = 11  # t0 .. t10

with torch.no_grad():
    ## Initialize hidden state
    pred_h = model.init_hidden(batch_size)

    # t0 (generate_current) and t1 (generate_next) both come from the replicated
    # input x and the freshly initialised hidden state
    t0 = model.generate_current(x, pred_h, 0)
    t1 = model.generate_next(x, pred_h, 1)
    generated = [t0, t1]
    update_x = torch.stack(generated + [t1] * (N_TIMEPOINTS - len(generated)), dim=1)

    predicted_labels = []
    for step in range(1, N_TIMEPOINTS):
        if step >= 2:
            # generate the next time bin from the previous sequence and hidden state
            t_next = model.generate_next(update_x, pred_h, step)
            generated.append(t_next)
            update_x = torch.stack(generated + [t_next] * (N_TIMEPOINTS - len(generated)), dim=1)
        proba, pred_h = model.predict_proba(update_x, pred_h, step)
        predicted_labels.append([num2class[i] for i in np.argmax(proba.cpu().detach().numpy(), axis=1)])

## Expose per-time-bin results under the names used by the following cells
t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10 = generated
y1, y2, y3, y4, y5, y6, y7, y8, y9, y10 = predicted_labels
In [61]:
## Convert the generated expression tensors to NumPy arrays on the host.
## `.cpu()` (rather than `.to(device)`) keeps `.numpy()` valid even if the model ran
## on a GPU. Values that are already arrays pass through, so the cell can be re-run.
t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10 = [
    t.detach().cpu().numpy() if torch.is_tensor(t) else t
    for t in (t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10)
]
In [62]:
## Stack the 11 generated time bins into one (11 * batch_size) x n_genes matrix,
## with a cell-type label and a time-bin label per row. Rows for t0 and t1 are
## labelled 'Stem Cell'; the classifier output for t1 (y1) is not used.
pred_X = np.vstack((t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10))
pred_Y = ['Stem Cell'] * (2 * batch_size) + [
    label for step_labels in (y2, y3, y4, y5, y6, y7, y8, y9, y10) for label in step_labels
]
pred_T = [f't{step}' for step in range(11) for _ in range(batch_size)]
assert len(pred_Y) == len(pred_T) == pred_X.shape[0], 'rows, cell-type labels and time bins are misaligned'
In [63]:
## Persist the generated shr trajectories so the analysis below can restart from disk
pd.DataFrame(pred_X).to_csv("./pred_X_shr.csv", header=False)
for _csv_path, _values in (("./pred_Y_shr.csv", pred_Y), ("./pred_T_shr.csv", pred_T)):
    pd.DataFrame(_values).to_csv(_csv_path)
In [4]:
## Reload the generated shr cells; the first CSV column holds the row names
adata = sc.read_csv('./pred_X_shr.csv', first_column_names=True)
adata.var.index = gene_list['features']
pred_Y, pred_T = (pd.read_csv(csv_path) for csv_path in ('./pred_Y_shr.csv', './pred_T_shr.csv'))
adata.obs['celltype'] = pred_Y['0'].tolist()
adata.obs['timebin'] = pred_T['0'].tolist()
In [5]:
## Embed the generated cells with the standard scanpy workflow.
## NOTE: sc.pp.scale rewrites adata.X in place (clipped at max_value=10), so the
## per-gene UMAP panels further down show scaled values, not raw model outputs.
sc.pp.scale(adata, max_value=10)
sc.tl.pca(adata, svd_solver='arpack')
## kNN graph (30 neighbours) on the first 50 principal components
sc.pp.neighbors(adata, n_neighbors=30, n_pcs=50)
sc.tl.leiden(adata)
## PAGA graph (over the Leiden clusters by default); used as init_pos for the UMAP
sc.tl.paga(adata)
sc.pl.paga(adata) 
In [6]:
sc.tl.umap(adata=adata, init_pos='paga')  # initialise the UMAP layout from the PAGA graph
In [7]:
## Cell-type display order and colours, kept together so that each colour is tied
## to its category explicitly. scanpy matches `celltype_colors` to the categories
## by position; the old 14-entry list silently ignored its last 3 colours.
celltype_palette = {
    "Stem Cell": "#9400d3",
    "Columella": "#5ab953",
    "Lateral Root Cap": "#bfef45",
    "Atrichoblast": "#008080",
    "Trichoblast": "#21B6A8",
    "Cortex": "#82b6ff",
    "Endodermis": "#0000FF",
    "Phloem": "#e6194b",
    "Xylem": "#9a6324",
    "Procambium": "#ffe119",
    "Pericycle": "#ff9900",
}
adata.uns['celltype_colors'] = np.array(list(celltype_palette.values()), dtype=object)
adata.obs['celltype'] = pd.Categorical(adata.obs['celltype'], categories=list(celltype_palette))
assert adata.obs['celltype'].notna().all(), 'some predicted cell types are missing from celltype_palette'
In [8]:
sc.pl.umap(adata, color='celltype')  # UMAP coloured by predicted cell type
/hpc/group/pbenfeylab/ch416/miniconda3/envs/pytorch-gpu/lib/python3.8/site-packages/scanpy/plotting/_tools/scatterplots.py:392: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored
  cax = scatter(
In [9]:
## Order the time bins numerically (t0, t1, ..., t10). Left as plain strings, they
## get default (lexically sorted) categories: t0, t1, t10, t2, ... That would
## misassign the 11-step colour gradient below, which scanpy applies by position.
timebin_order = [f't{step}' for step in range(11)]
adata.obs['timebin'] = pd.Categorical(adata.obs['timebin'], categories=timebin_order)
adata.uns['timebin_colors'] = np.array([ '#5E4FA2', '#3288BD', '#66C2A5', '#ABDDA4', '#E6F598', '#FFFFBF', '#FEE08B', '#FDAE61', '#F46D43', '#D53E4F','#9E0142'])
In [10]:
sc.pl.umap(adata, color='timebin')  # UMAP coloured by generated time bin
/hpc/group/pbenfeylab/ch416/miniconda3/envs/pytorch-gpu/lib/python3.8/site-packages/scanpy/plotting/_tools/scatterplots.py:392: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored
  cax = scatter(
In [11]:
## Marker genes for the main cell types: (gene id, panel title)
marker_panels = [
    ('AT1G79840', 'AT1G79840 (GL2, Atrichoblast)'),
    ('AT5G49270', 'AT5G49270 (COBL9, Trichoblast)'),
    ('AT1G09750', 'AT1G09750 (CORTEX, Cortex)'),
    ('AT5G57620', 'AT5G57620 (MYB36, Endodermis)'),
    ('AT1G79430', 'AT1G79430 (APL, Phloem)'),
    ('AT1G71930', 'AT1G71930 (VND7, Xylem)'),
]
for gene_id, panel_title in marker_panels:
    sc.pl.umap(adata, color=gene_id, title=panel_title)
In [12]:
## SHORT-ROOT expression: plot AT4G37650 itself so the data matches the title
## (AT1G71930 is VND7, already shown among the markers above).
sc.pl.umap(adata, color='AT4G37650', title='AT4G37650 (SHR)')
In [13]:
scr_gene_id = 'AT3G54220'  # SCARECROW
sc.pl.umap(adata, color=scr_gene_id, title='AT3G54220 (SCR)')

Run GeneSys on the scr mutant input data

In [71]:
## Model output index -> cell-type name (and the inverse mapping)
classes = ['Columella', 'Lateral Root Cap', 'Phloem', 'Xylem', 'Procambium', 'Pericycle', 'Endodermis', 'Cortex', 'Atrichoblast', 'Trichoblast']
num2class = dict(enumerate(classes))
class2num = {name: idx for idx, name in num2class.items()}
In [72]:
## Replicate the same scr expression snapshot across all 11 time bins:
## x has shape (batch_size, 11, n_genes)
xm = torch.tensor(scrx.to_numpy(), dtype=torch.float32)
x = torch.stack([xm] * 11, dim=1)
In [73]:
## Roll GeneSys forward one time bin at a time. At step k the sequence fed to the
## model is [t0, ..., tk], padded to N_TIMEPOINTS by repeating tk; the model then
## predicts a cell-type label at step k. This is inference only, so autograd is
## disabled: otherwise the graphs of all generated batch_size x n_genes tensors
## would be kept in memory.
N_TIMEPOINTS = 11  # t0 .. t10

with torch.no_grad():
    ## Initialize hidden state
    pred_h = model.init_hidden(batch_size)

    # t0 (generate_current) and t1 (generate_next) both come from the replicated
    # input x and the freshly initialised hidden state
    t0 = model.generate_current(x, pred_h, 0)
    t1 = model.generate_next(x, pred_h, 1)
    generated = [t0, t1]
    update_x = torch.stack(generated + [t1] * (N_TIMEPOINTS - len(generated)), dim=1)

    predicted_labels = []
    for step in range(1, N_TIMEPOINTS):
        if step >= 2:
            # generate the next time bin from the previous sequence and hidden state
            t_next = model.generate_next(update_x, pred_h, step)
            generated.append(t_next)
            update_x = torch.stack(generated + [t_next] * (N_TIMEPOINTS - len(generated)), dim=1)
        proba, pred_h = model.predict_proba(update_x, pred_h, step)
        predicted_labels.append([num2class[i] for i in np.argmax(proba.cpu().detach().numpy(), axis=1)])

## Expose per-time-bin results under the names used by the following cells
t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10 = generated
y1, y2, y3, y4, y5, y6, y7, y8, y9, y10 = predicted_labels
In [74]:
## Convert the generated expression tensors to NumPy arrays on the host.
## `.cpu()` (rather than `.to(device)`) keeps `.numpy()` valid even if the model ran
## on a GPU. Values that are already arrays pass through, so the cell can be re-run.
t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10 = [
    t.detach().cpu().numpy() if torch.is_tensor(t) else t
    for t in (t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10)
]
In [75]:
## Stack the 11 generated time bins into one (11 * batch_size) x n_genes matrix,
## with a cell-type label and a time-bin label per row. Rows for t0 and t1 are
## labelled 'Stem Cell'; the classifier output for t1 (y1) is not used.
pred_X = np.vstack((t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10))
pred_Y = ['Stem Cell'] * (2 * batch_size) + [
    label for step_labels in (y2, y3, y4, y5, y6, y7, y8, y9, y10) for label in step_labels
]
pred_T = [f't{step}' for step in range(11) for _ in range(batch_size)]
assert len(pred_Y) == len(pred_T) == pred_X.shape[0], 'rows, cell-type labels and time bins are misaligned'
In [76]:
## Persist the generated scr trajectories so the analysis below can restart from disk
pd.DataFrame(pred_X).to_csv("./pred_X_scr.csv", header=False)
for _csv_path, _values in (("./pred_Y_scr.csv", pred_Y), ("./pred_T_scr.csv", pred_T)):
    pd.DataFrame(_values).to_csv(_csv_path)
In [14]:
## Reload the generated scr cells; the first CSV column holds the row names
adata = sc.read_csv('./pred_X_scr.csv', first_column_names=True)
adata.var.index = gene_list['features']
pred_Y, pred_T = (pd.read_csv(csv_path) for csv_path in ('./pred_Y_scr.csv', './pred_T_scr.csv'))
adata.obs['celltype'] = pred_Y['0'].tolist()
adata.obs['timebin'] = pred_T['0'].tolist()
In [15]:
## Embed the generated cells with the standard scanpy workflow.
## NOTE: sc.pp.scale rewrites adata.X in place (clipped at max_value=10), so the
## per-gene UMAP panels further down show scaled values, not raw model outputs.
sc.pp.scale(adata, max_value=10)
sc.tl.pca(adata, svd_solver='arpack')
## kNN graph (30 neighbours) on the first 50 principal components
sc.pp.neighbors(adata, n_neighbors=30, n_pcs=50)
sc.tl.leiden(adata)
## PAGA graph (over the Leiden clusters by default); used as init_pos for the UMAP
sc.tl.paga(adata)
sc.pl.paga(adata) 
In [16]:
sc.tl.umap(adata=adata, init_pos='paga')  # initialise the UMAP layout from the PAGA graph
In [17]:
## Cell-type display order and colours, kept together so that each colour is tied
## to its category explicitly. scanpy matches `celltype_colors` to the categories
## by position; the old 14-entry list silently ignored its last 3 colours.
celltype_palette = {
    "Stem Cell": "#9400d3",
    "Columella": "#5ab953",
    "Lateral Root Cap": "#bfef45",
    "Atrichoblast": "#008080",
    "Trichoblast": "#21B6A8",
    "Cortex": "#82b6ff",
    "Endodermis": "#0000FF",
    "Phloem": "#e6194b",
    "Xylem": "#9a6324",
    "Procambium": "#ffe119",
    "Pericycle": "#ff9900",
}
adata.uns['celltype_colors'] = np.array(list(celltype_palette.values()), dtype=object)
adata.obs['celltype'] = pd.Categorical(adata.obs['celltype'], categories=list(celltype_palette))
assert adata.obs['celltype'].notna().all(), 'some predicted cell types are missing from celltype_palette'
In [18]:
sc.pl.umap(adata, color='celltype')  # UMAP coloured by predicted cell type
/hpc/group/pbenfeylab/ch416/miniconda3/envs/pytorch-gpu/lib/python3.8/site-packages/scanpy/plotting/_tools/scatterplots.py:392: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored
  cax = scatter(
In [19]:
## Order the time bins numerically (t0, t1, ..., t10). Left as plain strings, they
## get default (lexically sorted) categories: t0, t1, t10, t2, ... That would
## misassign the 11-step colour gradient below, which scanpy applies by position.
timebin_order = [f't{step}' for step in range(11)]
adata.obs['timebin'] = pd.Categorical(adata.obs['timebin'], categories=timebin_order)
adata.uns['timebin_colors'] = np.array([ '#5E4FA2', '#3288BD', '#66C2A5', '#ABDDA4', '#E6F598', '#FFFFBF', '#FEE08B', '#FDAE61', '#F46D43', '#D53E4F','#9E0142'])
In [20]:
sc.pl.umap(adata, color='timebin')  # UMAP coloured by generated time bin
/hpc/group/pbenfeylab/ch416/miniconda3/envs/pytorch-gpu/lib/python3.8/site-packages/scanpy/plotting/_tools/scatterplots.py:392: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored
  cax = scatter(
In [21]:
## Marker genes for the main cell types: (gene id, panel title)
marker_panels = [
    ('AT1G79840', 'AT1G79840 (GL2, Atrichoblast)'),
    ('AT5G49270', 'AT5G49270 (COBL9, Trichoblast)'),
    ('AT1G09750', 'AT1G09750 (CORTEX, Cortex)'),
    ('AT5G57620', 'AT5G57620 (MYB36, Endodermis)'),
    ('AT1G79430', 'AT1G79430 (APL, Phloem)'),
    ('AT1G71930', 'AT1G71930 (VND7, Xylem)'),
]
for gene_id, panel_title in marker_panels:
    sc.pl.umap(adata, color=gene_id, title=panel_title)
In [22]:
## SHORT-ROOT expression: plot AT4G37650 itself so the data matches the title
## (AT1G71930 is VND7, already shown among the markers above).
sc.pl.umap(adata, color='AT4G37650', title='AT4G37650 (SHR)')
In [23]:
scr_gene_id = 'AT3G54220'  # SCARECROW
sc.pl.umap(adata, color=scr_gene_id, title='AT3G54220 (SCR)')
In [ ]: